{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "Training an MNIST Classifier\n",
    "=====\n",
    "## Custom Dataset, Model Checkpointing, and Fine-tune"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import glob\n",
    "import os.path as osp\n",
    "import numpy as np\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Custom Dataset\n",
    "PyTorch has many built-in datasets such as MNIST and CIFAR. In this tutorial, we demonstrate how to write your own dataset by implementing a custom MNIST dataset class. Use [this link](https://github.com/myleott/mnist_png/blob/master/mnist_png.tar.gz?raw=true) to download the mnist png dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MNIST(Dataset):\n",
    "    \"\"\"\n",
    "    A customized data loader for MNIST.\n",
    "    \"\"\"\n",
    "    def __init__(self,\n",
    "                 root,\n",
    "                 transform=None,\n",
    "                 preload=False):\n",
    "        \"\"\" Intialize the MNIST dataset\n",
    "        \n",
    "        Args:\n",
    "            - root: root directory of the dataset\n",
    "            - tranform: a custom tranform function\n",
    "            - preload: if preload the dataset into memory\n",
    "        \"\"\"\n",
    "        self.images = None\n",
    "        self.labels = None\n",
    "        self.filenames = []\n",
    "        self.root = root\n",
    "        self.transform = transform\n",
    "\n",
    "        # read filenames\n",
    "        for i in range(10):\n",
    "            filenames = glob.glob(osp.join(root, str(i), '*.png'))\n",
    "            for fn in filenames:\n",
    "                self.filenames.append((fn, i)) # (filename, label) pair\n",
    "                \n",
    "        # if preload dataset into memory\n",
    "        if preload:\n",
    "            self._preload()\n",
    "            \n",
    "        self.len = len(self.filenames)\n",
    "                              \n",
    "    def _preload(self):\n",
    "        \"\"\"\n",
    "        Preload dataset to memory\n",
    "        \"\"\"\n",
    "        self.labels = []\n",
    "        self.images = []\n",
    "        for image_fn, label in self.filenames:            \n",
    "            # load images\n",
    "            image = Image.open(image_fn)\n",
    "            # avoid too many opened files bug\n",
    "            self.images.append(image.copy())\n",
    "            image.close()\n",
    "            self.labels.append(label)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\" Get a sample from the dataset\n",
    "        \"\"\"\n",
    "        if self.images is not None:\n",
    "            # If dataset is preloaded\n",
    "            image = self.images[index]\n",
    "            label = self.labels[index]\n",
    "        else:\n",
    "            # If on-demand data loading\n",
    "            image_fn, label = self.filenames[index]\n",
    "            image = Image.open(image_fn)\n",
    "            \n",
    "        # May use transform function to transform samples\n",
    "        # e.g., random crop, whitening\n",
    "        if self.transform is not None:\n",
    "            image = self.transform(image)\n",
    "        # return image and label\n",
    "        return image, label\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"\n",
    "        Total number of samples in the dataset\n",
    "        \"\"\"\n",
    "        return self.len"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the MNIST dataset. \n",
    "# transforms.ToTensor() automatically converts PIL images to\n",
    "# torch tensors with range [0, 1]\n",
    "trainset = MNIST(\n",
    "    root='mnist_png/training',\n",
    "    preload=True, transform=transforms.ToTensor(),\n",
    ")\n",
    "# Use the torch dataloader to iterate through the dataset\n",
    "trainset_loader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=1)\n",
    "\n",
    "# load the testset\n",
    "testset = MNIST(\n",
    "    root='mnist_png/testing',\n",
    "    preload=True, transform=transforms.ToTensor(),\n",
    ")\n",
    "# Use the torch dataloader to iterate through the dataset\n",
    "testset_loader = DataLoader(testset, batch_size=1000, shuffle=False, num_workers=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(len(trainset))\n",
    "print(len(testset))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# functions to show an image\n",
    "def imshow(img):\n",
    "    npimg = img.numpy()\n",
    "    plt.imshow(np.transpose(npimg, (1, 2, 0)))\n",
    "\n",
    "# get some random training images\n",
    "dataiter = iter(trainset_loader)\n",
    "images, labels = dataiter.next()\n",
    "\n",
    "# show images\n",
    "imshow(torchvision.utils.make_grid(images))\n",
    "# print labels\n",
    "print(' '.join('%5s' % labels[j] for j in range(16)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use GPU if available, otherwise stick with cpu\n",
    "use_cuda = torch.cuda.is_available()\n",
    "torch.manual_seed(123)\n",
    "device = torch.device(cuda if use_cuda else \"cpu\")\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define a Conv Net\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n",
    "        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n",
    "        self.conv2_drop = nn.Dropout2d()\n",
    "        self.fc1 = nn.Linear(320, 50)\n",
    "        self.fc2 = nn.Linear(50, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = F.relu(F.max_pool2d(self.conv1(x), 2))\n",
    "        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n",
    "        x = x.view(-1, 320)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.dropout(x, training=self.training)\n",
    "        x = self.fc2(x)\n",
    "        return F.log_softmax(x, dim=1)\n",
    "\n",
    "model = Net().to(device)\n",
    "optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train the network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(epoch, log_interval=100):\n",
    "    model.train()  # set training mode\n",
    "    iteration = 0\n",
    "    for ep in range(epoch):\n",
    "        for batch_idx, (data, target) in enumerate(trainset_loader):\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            output = model(data)\n",
    "            loss = F.nll_loss(output, target)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            if iteration % log_interval == 0:\n",
    "                print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
    "                    ep, batch_idx * len(data), len(trainset_loader.dataset),\n",
    "                    100. * batch_idx / len(trainset_loader), loss.item()))\n",
    "            iteration += 1\n",
    "        test()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test():\n",
    "    model.eval()  # set evaluation mode\n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    with torch.no_grad():\n",
    "        for data, target in testset_loader:\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            output = model(data)\n",
    "            test_loss += F.nll_loss(output, target, size_average=False).item() # sum up batch loss\n",
    "            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability\n",
    "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
    "\n",
    "    test_loss /= len(testset_loader.dataset)\n",
    "    print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
    "        test_loss, correct, len(testset_loader.dataset),\n",
    "        100. * correct / len(testset_loader.dataset)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train(5)  # train 5 epochs should get you to about 97% accuracy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Save the model (model checkpointing)\n",
    "\n",
    "Now we have a model! Obviously we do not want to retrain the model everytime we want to use it. Plus if you are training a super big model, you probably want to save checkpoint periodically so that you can always fall back to the last checkpoint in case something bad happened or you simply want to test models at different training iterations.\n",
    "\n",
    "Model checkpointing is fairly simple in PyTorch. First, we define a helper function that can save a model to the disk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_checkpoint(checkpoint_path, model, optimizer):\n",
    "    state = {'state_dict': model.state_dict(),\n",
    "             'optimizer' : optimizer.state_dict()}\n",
    "    torch.save(state, checkpoint_path)\n",
    "    print('model saved to %s' % checkpoint_path)\n",
    "    \n",
    "def load_checkpoint(checkpoint_path, model, optimizer):\n",
    "    state = torch.load(checkpoint_path)\n",
    "    model.load_state_dict(state['state_dict'])\n",
    "    optimizer.load_state_dict(state['optimizer'])\n",
    "    print('model loaded from %s' % checkpoint_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create a brand new model\n",
    "model = Net().to(device)\n",
    "optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)\n",
    "test()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define a training loop with model checkpointing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_save(epoch, save_interval, log_interval=100):\n",
    "    model.train()  # set training mode\n",
    "    iteration = 0\n",
    "    for ep in range(epoch):\n",
    "        for batch_idx, (data, target) in enumerate(trainset_loader):\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            output = model(data)\n",
    "            loss = F.nll_loss(output, target)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            if iteration % log_interval == 0:\n",
    "                print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
    "                    ep, batch_idx * len(data), len(trainset_loader.dataset),\n",
    "                    100. * batch_idx / len(trainset_loader), loss.item()))\n",
    "            if iteration % save_interval == 0 and iteration > 0:\n",
    "                save_checkpoint('mnist-%i.pth' % iteration, model, optimizer)\n",
    "            iteration += 1\n",
    "        test()\n",
    "    \n",
    "    # save the final model\n",
    "    save_checkpoint('mnist-%i.pth' % iteration, model, optimizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_save(5, 500, 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create a new model\n",
    "model = Net().to(device)\n",
    "optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)\n",
    "# load from the final checkpoint\n",
    "load_checkpoint('mnist-4690.pth', model, optimizer)\n",
    "# should give you the final model accuracy\n",
    "test()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Fine-tune a model\n",
    "\n",
    "Sometimes you want to fine-tune a pretrained model instead of training a model from scratch. For example, if you want to train a model on a new dataset that contains natural images. To achieve the best performance, you can start with a model that's fully trained on ImageNet and fine-tune the model.\n",
    "\n",
    "Finetuning a model in PyTorch is super easy! First, let's find out what we saved in a checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# What's in a state dict?\n",
    "print(model.state_dict().keys())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Finetune the fc layers\n",
    "\n",
    "Now say we want to load the conv layers from the checkpoint and train the fc layers. We can simply load a subset of the state dict with the selected names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "checkpoint = torch.load('mnist-4690.pth')\n",
    "states_to_load = {}\n",
    "for name, param in checkpoint['state_dict'].items():\n",
    "    if name.startswith('conv'):\n",
    "        states_to_load[name] = param\n",
    "\n",
    "# Construct a new state dict in which the layers we want\n",
    "# to import from the checkpoint is update with the parameters\n",
    "# from the checkpoint\n",
    "model_state = model.state_dict()\n",
    "model_state.update(states_to_load)\n",
    "        \n",
    "model = Net().to(device)\n",
    "model.load_state_dict(model_state)\n",
    "optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train(1)  # training 1 epoch will get you to 93%!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Import pretrained weights in a different model\n",
    "\n",
    "We can even use the pretrained conv layers in a different model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SmallNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(SmallNet, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n",
    "        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n",
    "        self.conv2_drop = nn.Dropout2d()\n",
    "        self.fc1 = nn.Linear(320, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = F.relu(F.max_pool2d(self.conv1(x), 2))\n",
    "        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n",
    "        x = x.view(-1, 320)\n",
    "        x = self.fc1(x)\n",
    "        return F.log_softmax(x, dim=1)\n",
    "\n",
    "model = SmallNet().to(device)\n",
    "optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "checkpoint = torch.load('mnist-4690.pth')\n",
    "states_to_load = {}\n",
    "for name, param in checkpoint['state_dict'].items():\n",
    "    if name.startswith('conv'):\n",
    "        states_to_load[name] = param\n",
    "\n",
    "# Construct a new state dict in which the layers we want\n",
    "# to import from the checkpoint is update with the parameters\n",
    "# from the checkpoint\n",
    "model_state = model.state_dict()\n",
    "model_state.update(states_to_load)\n",
    "        \n",
    "model.load_state_dict(model_state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train(1)  # training 1 epoch will get you to 93%!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
